LSTM

一种常用的 循环神经网络(RNN) 模块,用于处理具有时序依赖特征的数据(如语音、文本、时间序列等)。每个时间步的公式化描述如下。

\[\begin{split}\begin{aligned} i_t &= \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) && \text{(输入门)} \\[6pt] f_t &= \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) && \text{(遗忘门)} \\[6pt] g_t &= \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) && \text{(候选状态)} \\[6pt] o_t &= \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) && \text{(输出门)} \\[6pt] c_t &= f_t \odot c_{t-1} + i_t \odot g_t && \text{(细胞状态更新)} \\[6pt] h_t &= o_t \odot \tanh(c_t) && \text{(隐藏状态更新)} \end{aligned}\end{split}\]
  • \(x_t\) : 当前时间步输入向量

  • \(h_{t-1}\) : 上一时间步的隐藏状态

  • \(c_{t-1}\) : 上一时间步的细胞状态

  • \(i_t, f_t, g_t, o_t\) : 四个门(输入门、遗忘门、候选门、输出门)

  • \(W_*\) : 对应的权重矩阵

  • \(b_*\) : 偏置项

  • \(\sigma(\cdot)\) : Sigmoid 函数

  • \(\odot\) : 元素乘

输入:
  • input - 输入序列数据,形状为 \((seq_len, batch, input_size)\),即每个时间步的输入特征。

  • weight_i - 输入到各门 \((input、forget、cell、output)\) 的权重矩阵,大小为 4 * hidden_size * input_size。

  • weight_h - 上一隐藏状态到各门的权重矩阵,大小为 \(4 * hidden_size * hidden_size\)

  • input_bias - 输入部分的偏置项,对应 4 个门的偏置。

  • state_bias - 隐藏状态部分的偏置项(也是 \(4 * hidden_size\)),与 input_bias 一起求和形成总偏置。

  • hidden_state - 当前批次初始隐藏状态输入( \(h₀\) ),执行后更新为最后时刻的隐藏状态输出( \(hₜ\)

  • cell_state - 当前批次初始细胞状态输入( \(c₀\)),执行后更新为最后时刻的细胞状态输出( \(cₜ\))。

  • buffer - 临时工作区指针数组(中间计算缓存,如门值、激活结果、临时矩阵等,用于优化性能)。

  • LstmParameter - LSTM 配置参数结构体,包含输入大小、隐藏层维度、序列长度、是否双向等信息。

  • core_mask - 核掩码(仅适用于共享存储版本)。

LstmParameter定义:

 1typedef struct LstmParameter {
 2int input_size_;//每个时间步输入向量的维度(输入特征数)。
 3int hidden_size_;//LSTM 隐藏状态的维度(每个门的内部计算大小)。
 4int project_size_;//投影层输出维度(用于 LSTMP,有则在输出前线性压缩隐藏状态)。
 5int output_size_;//实际输出维度,等于 hidden_size_ 或 project_size_(取决于是否使用投影层)。
 6int seq_len_;//输入序列的时间步数(序列长度)。
 7int batch_;//批次大小(一次处理的样本数量)。
 8// other parameter
 9int output_step_;//指定输出第几个时间步的结果(通常为最后一步或每步)。
10bool bidirectional_;//是否为双向 LSTM(true 表示前向和后向各一层)。
11float zoneout_cell_;//单元状态的 Zoneout 比例(防止过拟合的正则化参数)。
12float zoneout_hidden_;//隐藏状态的 Zoneout 比例(防止过拟合)。
13int input_row_align_;//输入张量的行对齐参数(用于 DMA 或 SIMD 加速的内存对齐)。
14int input_col_align_;//输入张量的列对齐参数。
15int state_row_align_;//状态张量(hidden/cell)的行对齐参数。
16int state_col_align_;//状态张量的列对齐参数。
17int proj_col_align_;//投影层矩阵的列对齐参数。
18bool has_bias_;//是否包含偏置项(true 表示使用 bias)。
19} LstmParameter;
输出:
  • output - 计算结果地址,存放 LSTM 每个时间步输出结果的缓冲区,维度通常为 \((seq\_len, batch, output\_size)\)

支持平台:

FT78NE MT7004

备注

  • FT78NE 支持fp32

  • MT7004 支持fp32

共享存储版本:

void fp_Lstm_s(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state, float *buffer[9], const LstmParameter *lstm_param, int core_mask)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <lstm.h>
 4
 5int main(int argc, char* argv[]) {
 6    LstmParameter *lstm_param = (LstmParameter *)0x90000000;
 7    lstm_param->seq_len_ = 20;
 8    lstm_param->batch_ = 1;
 9    lstm_param->input_size_ = 2000;
10    lstm_param->hidden_size_ = 3;
11    lstm_param->bidirectional_ = false;
12    float * input = (float *)0xA0000000;   //input在DDR空间
13    float * weight_i = (float *)0xA1000000;
14    float * weight_h = (float *)0xA3000000;
15    float *input_bias_ =(float *) 0xB0900000;
16    float * state_bias_ =(float *) 0xB0B00000;
17    float * output_s = (float *)0xC0000000;
18    float *hidden_state_s = (float *)0xC0100000;
19    float *cell_state_s = (float *)0xC0200000;
20    float *buffer[9];
21    float * packed_input_ = (float *)0xB0000000;
22    buffer[0] = packed_input_;
23    float * gate = (float *)0xB0100000;
24    buffer[1] = gate;
25    float * packed_state = (float *)0xB0200000;
26    buffer[2] = packed_state;
27    float * state_gate = (float *)0xB0300000;
28    buffer[3] = state_gate;
29    float * cell_buffer = (float *)0xB0400000;
30    buffer[4] = cell_buffer;
31    float * hidden_buffer = (float *)0xB0500000;
32    buffer[5] = hidden_buffer;
33    float * packed_output = (float *)0xB0600000;
34    buffer[6] = packed_output;
35    float * left_matrix = (float *)0xB0700000;
36    buffer[7] = left_matrix;
37    float * packed_ptr = (float *)0xB0800000;
38    buffer[8] = packed_ptr;
39    int core_mask = 0xff;
40    fp_Lstm_s(output_s, input, weight_i, weight_h, input_bias_,
41    state_bias, hidden_state_s, cell_state_s, buffer,
42    lstm_param, core_mask);
43    return 0;
44}

私有存储版本:

void fp_Lstm_p(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state, float *buffer[9], const LstmParameter *lstm_param)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <lstm.h>
 4int main(int argc, char* argv[]) {
 5    LstmParameter *lstm_param = (LstmParameter *)0x10000000;
 6    lstm_param->seq_len_ = 4;
 7    lstm_param->batch_ = 1;
 8    lstm_param->input_size_ = 2;
 9    lstm_param->hidden_size_ = 3;
10    lstm_param->bidirectional_ = false;
11    float * input = (float *)0x10000200;   //input在DDR空间
12    float * weight_i = (float *)0x10000400;
13    float * weight_h = (float *)0x10000600;
14    float *input_bias_ =(float *) 0x10000800;
15    float * state_bias_ =(float *) 0x10000A00;
16    float * output_s = (float *)0x10000C00;
17    float *hidden_state_s = (float *)0x10000E00;
18    float *cell_state_s = (float *)0x10001000;
19    float *buffer[9];
20    float * packed_input_ = (float *)0x10001200;
21    buffer[0] = packed_input_;
22    float * gate = (float *)0x10001400;
23    buffer[1] = gate;
24    float * packed_state = (float *)0x10001600;
25    buffer[2] = packed_state;
26    float * state_gate = (float *)0x10001800;
27    buffer[3] = state_gate;
28    float * cell_buffer = (float *)0x10001A00;
29    buffer[4] = cell_buffer;
30    float * hidden_buffer = (float *)0x10001C00;
31    buffer[5] = hidden_buffer;
32    float * packed_output = (float *)0x10001F00;
33    buffer[6] = packed_output;
34    float * left_matrix = (float *)0x10002000;
35    buffer[7] = left_matrix;
36    float * packed_ptr = (float *)0x10002200;
37    buffer[8] = packed_ptr;
38    fp_Lstm_p(output_s, input, weight_i, weight_h, input_bias_,
39    state_bias, hidden_state_s, cell_state_s, buffer,
40    lstm_param);
41    return 0;
42}